Sequence Density Estimation using Field Theory

In this notebook, we show how to infer sequence densities across a large sequence space from a limited number of observations using SeqDEFT.

[1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

import gpmap.src.plot as plot
import gpmap.src.inference as inf

from gpmap.src.space import SequenceSpace
from gpmap.src.randwalk import WMWSWalk

from scipy.special import logsumexp
from scipy.stats import pearsonr

Simulate landscape using VC regression

[2]:
np.random.seed(1)
seq_length = 5
n_alleles = 4

lambdas = [0] + [10**(3-i) for i in range(seq_length)]
n_seqs = 1000
lambdas
[2]:
[0, 1000, 100, 10, 1, 0.1]
[3]:
vc = inf.VCregression()
vc.init(seq_length=seq_length, n_alleles=n_alleles)
phi = vc.simulate(lambdas)['function']
[4]:
Q_real = np.exp(phi - logsumexp(phi))
counts = np.random.multinomial(n=n_seqs, pvals=Q_real)
data = pd.DataFrame({'Q_real': Q_real, 'counts': counts}, index=vc.genotypes)
print("% of possible genotypes observed at least once: {:.2f}".format((data['counts'] > 0).mean() * 100))
data.loc[data['counts'] > 0]
% of possible genotypes observed at least once: 7.03
[4]:
Q_real counts
AAACC 0.002450 6
AADCA 0.001181 1
AADCC 0.004633 4
AADCD 0.001260 2
ABACA 0.000781 1
... ... ...
DCCDA 0.003278 4
DCCDC 0.001533 1
DCCDD 0.002695 3
DCDCA 0.001202 2
DCDDD 0.000838 1

72 rows × 2 columns

Note that only a small fraction of the possible sequences (72 of 1,024, about 7%) has been observed at least once; most sequences have not been observed at all.
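
The normalization and sampling steps used to simulate the counts can be reproduced on a toy example. The following is a minimal sketch using only NumPy and SciPy; the `phi_toy` values and all variable names are hypothetical and are not taken from the simulated landscape.

```python
import numpy as np
from scipy.special import logsumexp

# Hypothetical log-scale values for 4 genotypes (illustrative only)
phi_toy = np.array([0.0, 1.0, 2.0, -1.0])

# Subtracting logsumexp normalizes in log space, so the densities sum to 1
Q_toy = np.exp(phi_toy - logsumexp(phi_toy))

# Draw a fixed number of observations from the resulting distribution
rng = np.random.default_rng(0)
counts_toy = rng.multinomial(n=10, pvals=Q_toy)

# Fraction of genotypes observed at least once
observed_fraction = (counts_toy > 0).mean()
print(Q_toy, counts_toy, observed_fraction)
```

Working in log space before exponentiating avoids overflow when the values of `phi` are large, which is why `logsumexp` is preferred over `np.log(np.sum(np.exp(phi)))`.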

Run SeqDEFT inference on simulated count data

[5]:
seqdeft = inf.SeqDEFT(P=2)
inf_densities = seqdeft.fit(X=data.index.values, counts=data['counts'].values)
100%|█████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [01:04<00:00,  1.56it/s]
[6]:
inf_densities.head()
[6]:
frequency Q_star
AAAAA 0.0 4.277732e-06
AAAAB 0.0 3.085019e-07
AAAAC 0.0 2.792703e-06
AAAAD 0.0 2.103992e-06
AAABA 0.0 5.918343e-06
[7]:
data = data.join(inf_densities)
data
[7]:
Q_real counts frequency Q_star
AAAAA 3.661469e-06 0 0.0 4.277732e-06
AAAAB 5.151412e-09 0 0.0 3.085019e-07
AAAAC 4.113687e-06 0 0.0 2.792703e-06
AAAAD 1.991108e-09 0 0.0 2.103992e-06
AAABA 5.894597e-07 0 0.0 5.918343e-06
... ... ... ... ...
DDDCD 8.306988e-10 0 0.0 7.816898e-06
DDDDA 5.068515e-12 0 0.0 1.402378e-06
DDDDB 1.185211e-11 0 0.0 1.347608e-07
DDDDC 4.666807e-09 0 0.0 5.506277e-07
DDDDD 2.894802e-10 0 0.0 1.497771e-06

1024 rows × 4 columns

[8]:
fig = plot.plot_SeqDEFT_summary(seqdeft.log_Ls, inf_densities)
../_images/usage_4_SeqDEFT_11_0.png

This plot shows how the cross-validated log-likelihood on held-out data changes across the tested values of the hyperparameter a. The optimal value is far from 0 but is also finite, suggesting that the optimal solution lies somewhere between the maximum entropy solution and the empirical frequencies.
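
To illustrate why an intermediate amount of smoothing can maximize the held-out likelihood, here is a minimal sketch that swaps SeqDEFT for a much simpler estimator: additive (pseudocount) smoothing. This is not the SeqDEFT prior and not how `gpmap` computes `log_Ls`. The pseudocount `alpha` is only loosely analogous to the hyperparameter a: very small values approach the empirical frequencies and very large values approach the uniform, maximum entropy distribution. All names and values below are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical true distribution over 50 categories
n_categories = 50
q_true = rng.dirichlet(np.ones(n_categories) * 0.5)

# Sample individual observations and split them into training and held-out sets
obs = rng.choice(n_categories, size=200, p=q_true)
train, heldout = obs[:150], obs[150:]
train_counts = np.bincount(train, minlength=n_categories)
heldout_counts = np.bincount(heldout, minlength=n_categories)


def heldout_loglik(alpha):
    """Held-out log-likelihood under additive smoothing with pseudocount alpha."""
    q_hat = (train_counts + alpha) / (train_counts.sum() + alpha * n_categories)
    return float(np.sum(heldout_counts * np.log(q_hat)))


# Scan a grid of smoothing strengths and keep the one with the best held-out fit
alphas = 10.0 ** np.arange(-3, 4)
logliks = np.array([heldout_loglik(a) for a in alphas])
best_alpha = alphas[np.argmax(logliks)]
print(best_alpha)
```

Plotting `logliks` against `alphas` on a logarithmic x-axis gives a curve that can be read in the same way as the cross-validation panel, although the underlying model is different.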

The second plot shows the relationship between the observed frequencies in the data and the inferred densities for each genotype. For some genotypes, the empirical frequency and the inferred density differ by up to an order of magnitude. These differences arise because the log-densities are smoothed over sequence space to reduce the influence of noisy data.
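
One way to quantify the gap between empirical frequencies and inferred densities is the log10 ratio of the two for the observed genotypes. The following is a minimal sketch on a small hypothetical table that reuses the column names of `inf_densities`; the numbers are made up for illustration, and reading `frequency` as counts divided by the total number of observations is an assumption.

```python
import numpy as np
import pandas as pd

# Hypothetical table mimicking the columns of `inf_densities` (illustrative values)
toy = pd.DataFrame(
    {
        "frequency": [0.006, 0.001, 0.0, 0.004],
        "Q_star": [0.0030, 0.0012, 0.00001, 0.0041],
    },
    index=["AAACC", "AADCA", "AAAAA", "AADCC"],
)

# Unobserved genotypes have frequency 0, so the ratio is only defined for observed ones
observed = toy.loc[toy["frequency"] > 0]

# Positive values mean the inferred density is lower than the empirical frequency;
# an absolute value of 1 corresponds to one order of magnitude
log10_ratio = np.log10(observed["frequency"] / observed["Q_star"])

# Largest fold difference between the two estimates among observed genotypes
max_fold = 10 ** log10_ratio.abs().max()
print(log10_ratio, max_fold)
```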

[9]:
fig, axes = plt.subplots(1, 1, figsize=(4, 4))

Q = data[['Q_real', 'Q_star']].values.flatten()
lims = (Q.min(), Q.max())
axes.scatter(data['Q_real'], data['Q_star'], s=5)
axes.plot(lims, lims, lw=0.5, linestyle='--', c='grey')
axes.set(xlabel=r'$Q_{real}$', ylabel=r'$Q_{inferred}$', xscale='log', yscale='log',
         xlim=lims, ylim=lims)
pearsonr(np.log(data['Q_real']), np.log(data['Q_star']))
[9]:
(0.7595841433125253, 3.9165228011520593e-193)
../_images/usage_4_SeqDEFT_13_1.png

Visualizing the inferred landscape

[13]:
space = SequenceSpace(seq_length=seq_length, n_alleles=n_alleles, alphabet_type='custom',
                      function=np.log10(inf_densities['Q_star']))
rw = WMWSWalk(space)
[14]:
plot.figure_Ns_grid(rw, fmax=-2.5)
../_images/usage_4_SeqDEFT_16_0.png
[15]:
rw.calc_visualization(mean_function=-2.5, n_components=20)
nodes_df, edges_df = rw.nodes_df, rw.space.get_edges_df()
[16]:
fig, axes = plot.init_fig(1, 1, colsize=5, rowsize=4.5)
plot.plot_relaxation_times(rw.decay_rates_df, axes)
../_images/usage_4_SeqDEFT_18_0.png
[17]:
plot.plot_interactive(nodes_df, edges_df=edges_df, z='3', nodes_size=2)
[18]:
plot.figure_allele_grid(nodes_df, edges_df=edges_df, autoscale_axis=False)
../_images/usage_4_SeqDEFT_20_0.png